import pandas as pd
import yaml
import os
import matplotlib.pyplot as plt
from lidar_report import creat_report_main

# Script: merge the per-sensor summary CSVs into one comparison table and
# draw one MAE comparison chart per indicator.

plt.rcParams['font.sans-serif'] = ['SimHei']  # so that Chinese labels display correctly
plt.rcParams['axes.unicode_minus'] = False  # so that the minus sign displays correctly
sensors = ['lidar', 'radar', 'obfu']  # 'camera' is currently disabled

# Only error tables named here are read when collecting MAE values.
table_names = ['综合指标', '位置误差', '尺寸误差', '速度误差', '加速度误差']
tables = {}     # one (currently empty) dict per sensor; not filled any further here
MAE = {}        # indicator -> sensor -> {CSV column name -> MAE value}
sum_index = {}  # sensor -> merged target counts and 'all_class' summary values


def _load_report_config(sensor):
    """Load ``config/<sensor>_report_config.yaml`` and resolve its paths.

    ``output_path`` is joined onto ``root_path`` and ``sum_data_path`` is set
    to the ``sum_data`` directory below the resolved output path.

    Returns:
        The parsed config dict, updated in place with the resolved paths.
    """
    config_file = 'config/' + sensor + '_report_config.yaml'

    # creat_report_main(config_file)
    # NOTE(review): the call above is disabled; presumably it generates the
    # CSV files read by this script, so they must already exist - confirm.

    # Decode explicitly as UTF-8: the config holds Chinese indicator names that
    # must match `table_names`, and the platform default encoding may differ.
    with open(config_file, 'r', encoding='utf-8') as stream:
        # NOTE(security): full_load can build more than plain data types. The
        # config files are assumed to be trusted local files; consider
        # yaml.safe_load if they only contain plain mappings/lists/scalars.
        cfg = yaml.full_load(stream)
    cfg['output_path'] = os.path.join(cfg['root_path'], cfg['output_path'])
    cfg['sum_data_path'] = os.path.join(cfg['output_path'], 'sum_data')
    return cfg


def _mae_row(error_table, indicator):
    """Return the 'MAE' row of *indicator* as a ``{column: value}`` dict.

    The row is selected where 'Unnamed: 0' equals *indicator* and
    'Unnamed: 1' equals 'MAE'; both label columns are left out of the result.

    Raises:
        KeyError: if the table has no such row.
    """
    selected = error_table[(error_table['Unnamed: 0'] == indicator)
                           & (error_table['Unnamed: 1'] == 'MAE')]
    values = selected.set_index('Unnamed: 0').T[indicator].to_dict()
    values.pop('Unnamed: 1')
    return values


for sensor in sensors:
    tables.setdefault(sensor, {})
    config = _load_report_config(sensor)
    sum_dir = config['sum_data_path']

    # Summary row for this sensor: the row labelled 0 of '目标数.csv' merged
    # with the 'all_class' column of '综合指标.csv' (the latter wins on clashes).
    summary = pd.read_csv(os.path.join(sum_dir, '综合指标.csv')).set_index('Unnamed: 0')
    counts = pd.read_csv(os.path.join(sum_dir, '目标数.csv')).set_index('Unnamed: 0')
    sum_index[sensor] = {**counts.T.to_dict()[0], **summary['all_class'].to_dict()}

    for error_name, indicators in config['target_level_indicators'].items():
        if error_name not in table_names:
            continue
        error_table = pd.read_csv(os.path.join(sum_dir, error_name + '.csv'))
        for indicator in indicators:
            MAE.setdefault(indicator, {})[sensor] = _mae_row(error_table, indicator)

sum_table = pd.DataFrame(sum_index)
# `config` is the one loaded for the last sensor in `sensors`; its root_path
# decides where the combined output is written.
output = os.path.join(config['root_path'], 'output')
os.makedirs(output, exist_ok=True)
sum_table.to_csv(os.path.join(output, '综合指标.csv'))

for indicator, per_sensor in MAE.items():
    frame = pd.DataFrame(per_sensor)
    plt.plot(frame)
    # Label the lines from the frame's own columns: a sensor that did not
    # configure this indicator has no column, so using the global `sensors`
    # list would attach the wrong names to the remaining lines.
    plt.legend(list(frame.columns))
    plt.title(indicator + '误差对比')
    plt.savefig(os.path.join(output, indicator + '.jpg'))
    plt.close()